// Generate Go code from the .proto files.
//go:generate sh -c "protoc --proto_path=. --proto_path=../../../../../third_party   --go_out=paths=source_relative:./ ./*.proto"
// Generate Go error helpers (--go-errors_out) from the .proto files.
//go:generate sh -c "protoc --proto_path=. --proto_path=../../../../../third_party   --go-errors_out=paths=source_relative:./ ./*.proto"

package validate

import (
	"context"
	"gitee.com/guolianyu/kit/str"
	"github.com/go-kratos/kratos/v2/log"
	"github.com/go-kratos/kratos/v2/metadata"
	"github.com/go-kratos/kratos/v2/middleware"
	"github.com/go-kratos/kratos/v2/transport"
	"github.com/go-playground/locales/en"
	"github.com/go-playground/locales/pt_BR"
	"github.com/go-playground/locales/zh"
	ut "github.com/go-playground/universal-translator"
	govalidator "github.com/go-playground/validator/v10"
	"runtime"
	"strings"
)

// Request keys that carry the client's preferred language. The metadata key is
// checked before the header when resolving the locale of a request.
var (
	// clientLanguageMetaKey is the kratos metadata key holding the language.
	clientLanguageMetaKey = "x-md-global-language"
	// clientLanguageHeaderKey is the transport request header holding the language.
	clientLanguageHeaderKey = "Accept-Language"
)

// validatedKey is the unexported context key under which a CtxValidator is stored.
type validatedKey struct{}

// HandlerFunc converts a value recovered from a panic inside the Validator
// middleware into the error returned for the request. It receives the request
// context, the request, and the recovered value (as err).
type HandlerFunc func(ctx context.Context, req, err interface{}) error

// Option mutates the configuration of the Validator middleware.
type Option func(*options)

// options is the Validator middleware configuration:
//   - validator: the go-playground validator used for every request;
//   - uniTrans: the registry from which per-request translators are picked;
//   - localeFn: optional resolver that, when set, alone decides the locale;
//   - handler: turns recovered panics into errors;
//   - locale: the default locale used when the request specifies none.
type options struct {
	validator *govalidator.Validate
	uniTrans  *ut.UniversalTranslator
	localeFn  func(ctx context.Context) string
	handler   HandlerFunc
	locale    string
}

// CtxValidator pairs the shared go-playground validator with the translator
// selected for one request. uniTrans is a single-locale ut.Translator (despite
// the name), and locale is the string it was resolved from.
type CtxValidator struct {
	validator *govalidator.Validate
	uniTrans  ut.Translator
	locale    string
}

// Struct validates i with the wrapped go-playground validator.
//
// It returns nil when i is valid. When validation fails with
// govalidator.ValidationErrors, the result is an ErrorValidateError whose
// message is the translated message of the first failing field. Its metadata
// maps every failing field name, converted with str.ToSnake, to that field's
// translated message. If two fields map to the same key, the later one wins.
// Any other error from the validator is wrapped as
// ErrorValidateError(err.Error()). Per go-playground's docs, for example, a
// nil or non-struct argument yields an InvalidValidationError. In every
// failure case the original error is attached as the cause.
func (v CtxValidator) Struct(i interface{}) (err error) {
	if err = v.validator.Struct(i); err == nil {
		return nil
	}

	errs, ok := err.(govalidator.ValidationErrors)
	if !ok || len(errs) == 0 {
		return ErrorValidateError(err.Error()).WithCause(err)
	}

	meta := make(map[string]string, len(errs))
	var message string
	for idx, fe := range errs {
		fieldMsg := getErrFieldMessage(fe, v.uniTrans)
		if idx == 0 {
			// The first failure becomes the top-level message; reuse it
			// rather than rendering it a second time.
			message = fieldMsg
		}
		meta[str.ToSnake(fe.Field())] = fieldMsg
	}
	return ErrorValidateError(message).WithCause(err).WithMetadata(meta)
}

// Validator exposes the underlying *govalidator.Validate held by this
// CtxValidator, for callers that need the raw go-playground API.
func (v CtxValidator) Validator() *govalidator.Validate {
	validate := v.validator
	return validate
}

// getTranslator returns the translator to use for locale.
//
// locale may be either of:
//   - a bare locale name, such as "zh" or "pt_BR";
//   - a raw Accept-Language value, such as "zh-CN,zh;q=0.9,en;q=0.8".
//
// An exact match on the whole string is tried first. Otherwise the candidates
// produced by acceptLanguageCandidates are tried in order. If none is
// registered, the universal translator's fallback is returned.
func (o *options) getTranslator(locale string) ut.Translator {
	if translator, found := o.uniTrans.GetTranslator(locale); found {
		return translator
	}
	if translator, found := o.uniTrans.FindTranslator(acceptLanguageCandidates(locale)...); found {
		return translator
	}
	return o.uniTrans.GetFallback()
}

// acceptLanguageCandidates expands a locale or Accept-Language value into the
// go-playground locale names worth trying, in header order.
//
// Each comma-separated tag is processed as follows:
//   - any ";q=..." parameters are dropped;
//   - empty tags and the "*" wildcard are skipped;
//   - "-" is rewritten to "_" (BCP 47 "pt-BR" -> locales-package "pt_BR");
//   - the bare language ("pt") is added right after the full tag.
//
// Quality values are not used for ordering. The header order is assumed to
// already be the client's preference order, which is how browsers typically
// send it.
func acceptLanguageCandidates(locale string) []string {
	var candidates []string
	for _, tag := range strings.Split(locale, ",") {
		if i := strings.IndexByte(tag, ';'); i >= 0 {
			tag = tag[:i]
		}
		tag = strings.TrimSpace(tag)
		if tag == "" || tag == "*" {
			continue
		}
		tag = strings.ReplaceAll(tag, "-", "_")
		candidates = append(candidates, tag)
		if i := strings.IndexByte(tag, '_'); i > 0 {
			candidates = append(candidates, tag[:i])
		}
	}
	return candidates
}

// getLocaleFromContext determines the locale used to translate validation
// messages for the current request. Sources are checked in this order:
//  1. o.localeFn, if configured, alone decides the result;
//  2. the kratos server metadata value under clientLanguageMetaKey;
//  3. the transport request header clientLanguageHeaderKey;
//  4. the configured default o.locale.
//
// Each source applies only when it yields a non-empty value. The header is
// consulted whether or not server metadata is present in ctx. The value is
// returned verbatim, so it may be a full Accept-Language list; mapping it to a
// translator is left to the caller.
func (o *options) getLocaleFromContext(ctx context.Context) string {
	if o.localeFn != nil {
		return o.localeFn(ctx)
	}
	if md, ok := metadata.FromServerContext(ctx); ok {
		if lang := md.Get(clientLanguageMetaKey); lang != "" {
			return lang
		}
	}
	if tr, ok := transport.FromServerContext(ctx); ok {
		if lang := tr.RequestHeader().Get(clientLanguageHeaderKey); lang != "" {
			return lang
		}
	}
	return o.locale
}

// Validator returns a middleware that validates every request with
// go-playground/validator before it reaches the next handler.
//
// Defaults are:
//   - a fresh validator;
//   - locale "en";
//   - translators for en (the fallback), zh and pt_BR;
//   - a panic handler that returns ErrorValidateError("unknown validate error").
//
// opts are applied in order, and nil options are ignored.
//
// For each request:
//   - the locale is resolved from the context;
//   - the matching translator is selected;
//   - req is validated, and a validation failure is returned without calling
//     the next handler.
//
// On success, the per-request CtxValidator is stored in the context handed to
// the next handler, so the handler can retrieve it with FromContext.
//
// A panic raised during validation or anywhere further down the chain is
// recovered. It is logged with the request and a stack trace, and converted
// into the returned error by the configured HandlerFunc.
func Validator(opts ...Option) middleware.Middleware {
	o := &options{
		validator: govalidator.New(),
		locale:    "en",
		uniTrans:  ut.New(en.New(), zh.New(), pt_BR.New()),
		handler: func(ctx context.Context, req, err interface{}) error {
			return ErrorValidateError("unknown validate error")
		},
	}
	for _, opt := range opts {
		if opt != nil {
			opt(o)
		}
	}

	return func(handler middleware.Handler) middleware.Handler {
		return func(ctx context.Context, req interface{}) (reply interface{}, err error) {
			defer func() {
				if rerr := recover(); rerr != nil {
					buf := make([]byte, 64<<10) //nolint:gomnd
					n := runtime.Stack(buf, false)
					buf = buf[:n]
					log.Context(ctx).Errorf("%v: %+v\n%s\n", rerr, req, buf)
					err = o.handler(ctx, req, rerr)
				}
			}()

			locale := o.getLocaleFromContext(ctx)
			v := CtxValidator{
				validator: o.validator,
				uniTrans:  o.getTranslator(locale),
				locale:    locale,
			}
			if err = v.Struct(req); err != nil {
				return nil, err
			}
			// NewContext returns a derived context; it must be the one passed
			// downstream, otherwise FromContext in the handler never sees v.
			return handler(NewContext(ctx, v), req)
		}
	}
}

// NewContext returns a child of ctx that carries v. The returned context, not
// ctx, must be propagated for FromContext to find v.
func NewContext(ctx context.Context, v CtxValidator) context.Context {
	derived := context.WithValue(ctx, validatedKey{}, v)
	return derived
}

// FromContext returns the CtxValidator stored in ctx by NewContext.
//
// When ctx carries none, it builds a fresh one with the following settings:
//   - a new validator;
//   - locale "en";
//   - the fallback translator of a universal translator over en, zh and pt_BR.
func FromContext(ctx context.Context) CtxValidator {
	stored := ctx.Value(validatedKey{})
	if stored == nil {
		return CtxValidator{
			validator: govalidator.New(),
			locale:    "en",
			uniTrans:  ut.New(en.New(), zh.New(), pt_BR.New()).GetFallback(),
		}
	}
	return stored.(CtxValidator)
}

// getErrFieldMessage renders one validation failure as a message in
// unTrans's language.
//
// The message is err.Translate(unTrans). If unTrans also has a translation for
// the raw field name, the first occurrence of the field name in the message is
// replaced by that translation. A translation counts only when
// unTrans.T(err.Field()) succeeds and differs from the name.
func getErrFieldMessage(err govalidator.FieldError, unTrans ut.Translator) (msg string) {
	field := err.Field()
	msg = err.Translate(unTrans)
	// A field-name translation is optional; T reports an error when none is
	// registered, in which case the message is used as is.
	if label, terr := unTrans.T(field); terr == nil && label != field {
		msg = strings.Replace(msg, field, label, 1)
	}
	return msg
}
